{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, \"/home/husein/parsing/self-attentive-parser/src\")\n",
    "sys.path.append(\"/home/husein/parsing/self-attentive-parser\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from transformers import AlbertTokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AlbertTokenizer\n",
    "tokenizer = AlbertTokenizer.from_pretrained(\n",
    "    'huseinzol05/albert-base-bahasa-cased',\n",
    "    do_lower_case = False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "with open('vocab-albert-base.json') as fopen:\n",
    "    data = json.load(fopen)\n",
    "    \n",
    "LABEL_VOCAB = data['label']\n",
    "TAG_VOCAB = data['tag']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "with tf.gfile.GFile('export/albert-base.pb', 'rb') as f:\n",
    "    graph_def = tf.GraphDef()\n",
    "    graph_def.ParseFromString(f.read())\n",
    "    with tf.Graph().as_default() as graph:\n",
    "        tf.import_graph_def(graph_def)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_ids = graph.get_tensor_by_name('import/input_ids:0')\n",
    "word_end_mask = graph.get_tensor_by_name('import/word_end_mask:0')\n",
    "charts = graph.get_tensor_by_name('import/charts:0')\n",
    "tags = graph.get_tensor_by_name('import/tags:0')\n",
    "sess = tf.InteractiveSession(graph = graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "BERT_MAX_LEN = 512\n",
    "import numpy as np\n",
    "from parse_nk import BERT_TOKEN_MAPPING\n",
    "\n",
    "def make_feed_dict_bert(sentences):\n",
    "    all_input_ids = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)\n",
    "    all_word_end_mask = np.zeros((len(sentences), BERT_MAX_LEN), dtype=int)\n",
    "    \n",
    "\n",
    "    subword_max_len = 0\n",
    "    for snum, sentence in enumerate(sentences):\n",
    "        tokens = []\n",
    "        word_end_mask = []\n",
    "\n",
    "        tokens.append(u\"[CLS]\")\n",
    "        word_end_mask.append(1)\n",
    "\n",
    "        cleaned_words = []\n",
    "        for word in sentence:\n",
    "            word = BERT_TOKEN_MAPPING.get(word, word)\n",
    "            # BERT is pre-trained with a tokenizer that doesn't split off\n",
    "            # n't as its own token\n",
    "            if word == u\"n't\" and cleaned_words:\n",
    "                cleaned_words[-1] = cleaned_words[-1] + u\"n\"\n",
    "                word = u\"'t\"\n",
    "            cleaned_words.append(word)\n",
    "\n",
    "        for word in cleaned_words:\n",
    "            word_tokens = tokenizer.tokenize(word)\n",
    "            if not word_tokens:\n",
    "                # The tokenizer used in conjunction with the parser may not\n",
    "                # align with BERT; in particular spaCy will create separate\n",
    "                # tokens for whitespace when there is more than one space in\n",
    "                # a row, and will sometimes separate out characters of\n",
    "                # unicode category Mn (which BERT strips when do_lower_case\n",
    "                # is enabled). Substituting UNK is not strictly correct, but\n",
    "                # it's better than failing to return a valid parse.\n",
    "                word_tokens = [\"[UNK]\"]\n",
    "            for _ in range(len(word_tokens)):\n",
    "                word_end_mask.append(0)\n",
    "            word_end_mask[-1] = 1\n",
    "            tokens.extend(word_tokens)\n",
    "        tokens.append(u\"[SEP]\")\n",
    "        word_end_mask.append(1)\n",
    "\n",
    "        input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "        subword_max_len = max(subword_max_len, len(input_ids))\n",
    "\n",
    "        all_input_ids[snum, :len(input_ids)] = input_ids\n",
    "        all_word_end_mask[snum, :len(word_end_mask)] = word_end_mask\n",
    "\n",
    "    all_input_ids = all_input_ids[:, :subword_max_len]\n",
    "    all_word_end_mask = all_word_end_mask[:, :subword_max_len]\n",
    "    return all_input_ids, all_word_end_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[   2,  280,  457, 1520,  597,  454, 3794,    3]]),\n",
       " array([[1, 1, 1, 1, 1, 1, 1, 1]]))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "s = 'Saya sedang membaca buku tentang Perlembagaan'.split()\n",
    "sentences = [s]\n",
    "i, m = make_feed_dict_bert(sentences)\n",
    "i, m"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[[[ 0.        , -4.6400285 , -3.263862  , ..., -2.365652  ,\n",
       "           -2.3575218 , -2.1976795 ],\n",
       "          [ 0.        , -1.9336585 , -2.5753968 , ..., -1.7595727 ,\n",
       "           -1.7113284 , -1.9766747 ],\n",
       "          [ 0.        , -1.8684319 , -2.5347548 , ..., -2.1521864 ,\n",
       "           -1.9502559 , -2.4117303 ],\n",
       "          ...,\n",
       "          [ 0.        , -2.4688456 , -5.0541735 , ..., -2.1501265 ,\n",
       "           -2.7817178 , -2.045808  ],\n",
       "          [ 0.        ,  2.0402954 , -1.3372381 , ..., -2.0561574 ,\n",
       "           -2.2189684 , -2.2715192 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        , -3.761474  , -2.1691635 , ..., -2.5855842 ,\n",
       "           -2.1737652 , -2.256596  ],\n",
       "          [ 0.        , -4.6400285 , -3.263862  , ..., -2.365652  ,\n",
       "           -2.3575218 , -2.1976795 ],\n",
       "          [ 0.        , -3.0972695 , -1.830802  , ..., -2.2955809 ,\n",
       "           -2.1975942 , -2.3570776 ],\n",
       "          ...,\n",
       "          [ 0.        , -3.089801  , -4.995689  , ..., -2.4817433 ,\n",
       "           -3.064287  , -2.1675904 ],\n",
       "          [ 0.        , -1.3507465 , -1.708786  , ..., -2.7811885 ,\n",
       "           -2.4774568 , -2.4355247 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        , -3.4829345 , -2.622274  , ..., -2.0963447 ,\n",
       "           -1.7853098 , -1.9312341 ],\n",
       "          [ 0.        , -2.702062  , -2.5294354 , ..., -1.1650053 ,\n",
       "           -1.3782667 , -1.4497064 ],\n",
       "          [ 0.        , -4.6400285 , -3.263862  , ..., -2.365652  ,\n",
       "           -2.3575218 , -2.1976795 ],\n",
       "          ...,\n",
       "          [ 0.        , -2.6063166 , -5.185776  , ..., -2.0100474 ,\n",
       "           -2.6796756 , -2.1071324 ],\n",
       "          [ 0.        , -0.88271576, -2.1417198 , ..., -2.2307074 ,\n",
       "           -2.1729112 , -1.9563714 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         ...,\n",
       " \n",
       "         [[ 0.        , -3.3378694 , -0.7042449 , ..., -1.8816599 ,\n",
       "           -1.2442505 , -2.0501723 ],\n",
       "          [ 0.        , -2.584117  , -1.3452475 , ..., -1.6568699 ,\n",
       "           -1.3097196 , -1.9925401 ],\n",
       "          [ 0.        , -2.2831483 , -0.9533983 , ..., -1.8414954 ,\n",
       "           -1.3157982 , -2.4766014 ],\n",
       "          ...,\n",
       "          [ 0.        , -4.6400285 , -3.263862  , ..., -2.365652  ,\n",
       "           -2.3575218 , -2.1976795 ],\n",
       "          [ 0.        , -1.3462337 , -0.4044567 , ..., -2.0897412 ,\n",
       "           -1.2382249 , -2.5590096 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        , -4.0408955 , -1.5319583 , ..., -1.5409988 ,\n",
       "           -1.5442361 , -1.4087954 ],\n",
       "          [ 0.        , -4.04403   , -2.177527  , ..., -1.6317112 ,\n",
       "           -1.5265354 , -1.5319704 ],\n",
       "          [ 0.        , -4.3131905 , -2.1354446 , ..., -1.8686312 ,\n",
       "           -1.7331722 , -1.7169375 ],\n",
       "          ...,\n",
       "          [ 0.        , -4.2470737 , -4.897116  , ..., -2.019705  ,\n",
       "           -2.3836515 , -1.9605376 ],\n",
       "          [ 0.        , -4.6400285 , -3.263862  , ..., -2.365652  ,\n",
       "           -2.3575218 , -2.1976795 ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]],\n",
       " \n",
       "         [[ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          ...,\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ],\n",
       "          [ 0.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "            0.        ,  0.        ]]]], dtype=float32),\n",
       " array([[0, 4, 8, 9, 6, 3, 6, 1]]))"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "charts_val, tags_val = sess.run((charts, tags), {input_ids: i, word_end_mask: m})\n",
    "charts_val, tags_val"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "for snum, sentence in enumerate(sentences):\n",
    "    chart_size = len(sentence) + 1\n",
    "    chart = charts_val[snum,:chart_size,:chart_size,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !wget https://raw.githubusercontent.com/michaeljohns2/self-attentive-parser/michaeljohns2-support-tf2-patch/benepar/chart_decoder.pyx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "import chart_decoder_py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(17.99921,\n",
       " array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5]),\n",
       " array([6, 1, 6, 2, 6, 3, 6, 4, 6, 5, 6]),\n",
       " array([1, 4, 5, 0, 5, 0, 3, 3, 2, 0, 3]))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chart_decoder_py.decode(chart)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk\n",
    "from nltk import Tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "PTB_TOKEN_ESCAPE = {u\"(\": u\"-LRB-\",\n",
    "    u\")\": u\"-RRB-\",\n",
    "    u\"{\": u\"-LCB-\",\n",
    "    u\"}\": u\"-RCB-\",\n",
    "    u\"[\": u\"-LSB-\",\n",
    "    u\"]\": u\"-RSB-\"}\n",
    "\n",
    "\n",
    "def make_nltk_tree(sentence, tags, score, p_i, p_j, p_label):\n",
    "\n",
    "    # Python 2 doesn't support \"nonlocal\", so wrap idx in a list\n",
    "    idx_cell = [-1]\n",
    "    def make_tree():\n",
    "        idx_cell[0] += 1\n",
    "        idx = idx_cell[0]\n",
    "        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]\n",
    "        label = LABEL_VOCAB[label_idx]\n",
    "        if (i + 1) >= j:\n",
    "            word = sentence[i]\n",
    "            tag = TAG_VOCAB[tags[i]]\n",
    "            tag = PTB_TOKEN_ESCAPE.get(tag, tag)\n",
    "            word = PTB_TOKEN_ESCAPE.get(word, word)\n",
    "            tree = Tree(tag, [word])\n",
    "            for sublabel in label[::-1]:\n",
    "                tree = Tree(sublabel, [tree])\n",
    "            return [tree]\n",
    "        else:\n",
    "            left_trees = make_tree()\n",
    "            right_trees = make_tree()\n",
    "            children = left_trees + right_trees\n",
    "            if label:\n",
    "                tree = Tree(label[-1], children)\n",
    "                for sublabel in reversed(label[:-1]):\n",
    "                    tree = Tree(sublabel, [tree])\n",
    "                return [tree]\n",
    "            else:\n",
    "                return children\n",
    "\n",
    "    tree = make_tree()[0]\n",
    "    tree.score = score\n",
    "    return tree"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(S\n",
      "  (NP-SBJ (<START> Saya))\n",
      "  (VP\n",
      "    (PRP sedang)\n",
      "    (VP\n",
      "      (MD membaca)\n",
      "      (NP (NP (VB buku)) (PP (NN tentang) (NP (IN Perlembagaan)))))))\n"
     ]
    }
   ],
   "source": [
    "tree = make_nltk_tree(s, tags_val[0], *chart_decoder_py.decode(chart))\n",
    "print(str(tree))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_str_tree(sentence, tags, score, p_i, p_j, p_label):\n",
    "    idx_cell = [-1]\n",
    "    def make_str():\n",
    "        idx_cell[0] += 1\n",
    "        idx = idx_cell[0]\n",
    "        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]\n",
    "        label = LABEL_VOCAB[label_idx]\n",
    "        if (i + 1) >= j:\n",
    "            word = sentence[i]\n",
    "            tag = TAG_VOCAB[tags[i]]\n",
    "            tag = PTB_TOKEN_ESCAPE.get(tag, tag)\n",
    "            word = PTB_TOKEN_ESCAPE.get(word, word)\n",
    "            s = u\"({} {})\".format(tag, word)\n",
    "        else:\n",
    "            children = []\n",
    "            while ((idx_cell[0] + 1) < len(p_i)\n",
    "                and i <= p_i[idx_cell[0] + 1]\n",
    "                and p_j[idx_cell[0] + 1] <= j):\n",
    "                children.append(make_str())\n",
    "\n",
    "            s = u\" \".join(children)\n",
    "            \n",
    "        for sublabel in reversed(label):\n",
    "            s = u\"({} {})\".format(sublabel, s)\n",
    "        return s\n",
    "    return make_str()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'(S (NP-SBJ (<START> Saya)) (VP (PRP sedang) (VP (MD membaca) (NP (NP (VB buku)) (PP (NN tentang) (NP (IN Perlembagaan)))))))'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "make_str_tree(s, tags_val[0], *chart_decoder_py.decode(chart))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
